package snowflake

import (
	"encoding/json"
	"errors"
	"fmt"
	"github.com/samuel/go-zookeeper/zk"
	"github.com/sony/sonyflake"
	"log"
	"math"
	"net"
	"strconv"
	"strings"
	"time"
)

var (
	dateFmt = "2006-01-02"
	timeFmt = "2006-01-02 15:04:05.000"

	sf   *sonyflake.Sonyflake
	conn *zk.Conn
	cfg  *initConfig

	startDate time.Time //雪花算法时间戳部分的开始时间

	nodeFullForeverPath string //当前节点的完整持久节点路径
	nodeFullTempPath    string //当前节点的完整临时节点路径

	workerID string //当前节点的雪花算法workerID

	lastReportTimestamp int64 //上次定时汇报时间戳（纳秒）
	lastGenTimestamp    int64 //上次生成id的时间戳（纳秒）
)

//Init 初始化基于ZK的分布式ID服务核心
func Init(zkServers []string, options ...Option) error {
	if len(zkServers) == 0 {
		return errors.New("zkServers不包含任何元素")
	}
	for _, v := range zkServers {
		if strings.TrimSpace(v) == "" {
			return errors.New("zkServers包含空元素")
		}
	}

	ip, err := getLocalIP()
	if err != nil {
		return errors.New("无法获取本机有效IP地址")
	}

	cfg = &initConfig{
		startDate:       "2020-01-01",
		foreverRootPath: "/dna_forever",
		tempRootPath:    "/dna_temp",
		nodeEndpoint:    fmt.Sprintf("%v:%v", ip.String(), 10000),
	}

	for _, v := range options {
		v(cfg)
	}

	startDate, err = time.Parse(dateFmt, cfg.startDate)
	if err != nil {
		return fmt.Errorf("startDate格式无效:%w", err)
	}
	nowDate, _ := time.Parse(dateFmt, time.Now().Format(dateFmt))
	if startDate.After(nowDate) {
		return errors.New("startDate晚于当前日期")
	}

	if strings.TrimSpace(cfg.foreverRootPath) == "" {
		return errors.New("foreverRootPath为空")
	}

	if strings.TrimSpace(cfg.tempRootPath) == "" {
		return errors.New("tempRootPath为空")
	}

	if strings.TrimSpace(cfg.nodeEndpoint) == "" {
		return errors.New("nodeEndpoint为空")
	}

	if err = connectToZK(zkServers); err != nil {
		return err
	}
	log.Println("zookeeper已连接")
	if err = checkRootPath(); err != nil {
		return err
	}
	log.Println("根目录已检查完毕")
	if err = getWorkerID(); err != nil {
		return err
	}

	if err = checkClusterTime(); err != nil {
		return err
	}
	log.Println("检查集群时间完毕")

	if err = initSF(); err != nil {
		return err
	}
	log.Println("初始化sonyflake完毕")

	if err = createTempNode(); err != nil {
		return err
	}
	log.Println("创建临时节点完毕")

	//定时汇报状态
	reportNodeState()
	log.Println("开始定时汇报状态")

	return nil
}

func GetID() (uint64, error) {
	now := getNowUnixNano()
	if lastGenTimestamp == 0 || now >= lastGenTimestamp {
		return genID(sf, now)
	}
	//发生了回拨，此刻时间小于上次发号时间
	offset := lastGenTimestamp - now
	if offset > 5*int64(time.Millisecond) { //1000000纳秒等于1毫秒
		return 0, errors.New("时间回拨超过5ms，不能发号")
	}

	//时间偏差小于5ms，等待两倍时间
	time.Sleep(2 * time.Duration(offset) * time.Nanosecond)
	now = getNowUnixNano()
	if now < lastGenTimestamp {
		//等待两倍偏差后仍然小于，无法恢复，不能继续发号
		return 0, errors.New("等待两倍偏差后仍然小于，无法恢复，不能继续发号")
	}
	return genID(sf, now)
}

func checkClusterTime() error {
	//无论当前节点是否注册过 均需要检查集群时间有效性
	//因为本机时间未回退不代表在整个集群中它的时间是正常的

	//todo: 由于临时节点只有会话过期才会删除，所以可能获得的某个节点时间比较早 造成不准确 后续改成通过rpc请求各节点
	//获取临时根节点下的所有子节点，这些节点代表短时间内有效的节点
	//然后比较集群间时间差异 只要差异在阈值范围内就算有效
	tmpChildren, _, err := conn.Children(cfg.tempRootPath)
	if err != nil {
		return fmt.Errorf("获取临时根节点%v下的子节点发生错误：%w", cfg.tempRootPath, err)
	}
	if len(tmpChildren) == 0 {
		log.Println("临时根节点下无任何节点，无需检查集群时间有效性")
		return nil
	}
	var timeSum uint64
	var validCount int
	for _, path := range tmpChildren {
		state, err := getState(cfg.tempRootPath + "/" + path)
		if err != nil {
			continue
		}
		//理论上会溢出但实际不会，即使到2100年，一百万个节点的毫秒时间戳相加也远远没到uint64最大值
		timeSum += uint64(state.LastTimestamp)
		validCount++
	}
	if validCount == 0 {
		return fmt.Errorf("无法正确获取其他有效节点的状态")
	}
	avg := timeSum / uint64(validCount)
	//abs( 系统时间-avg ) < 阈值
	abs := math.Abs(float64(uint64(time.Now().UnixNano()/1000000) - avg))
	if abs > float64(time.Second) { //差距大于1秒则认为偏移
		return fmt.Errorf("本机系统时间发生大步长偏移！")
	}
	log.Println("集群时间有效")
	return nil
}

func createTempNode() error {
	path := cfg.tempRootPath + "/" + cfg.nodeEndpoint
	state, err := createState()
	if err != nil {
		return err
	}
	_ = state
	//优先删除
	err = conn.Delete(path, -1)
	if err != nil && err != zk.ErrNoNode {
		return fmt.Errorf("删除临时节点失败:%w", err)
	}
	//临时节点主要用来表示节点存活
	_, err = conn.Create(path, state, zk.FlagEphemeral, zk.WorldACL(zk.PermAll))
	if err != nil {
		return fmt.Errorf("创建当前临时节点失败:%w", err)
	}
	nodeFullTempPath = path
	return nil
}

func initSF() error {
	sfSetting := sonyflake.Settings{
		StartTime: startDate,
		MachineID: func() (uint16, error) {
			parseUint, err := strconv.ParseUint(workerID, 10, 16)
			if err != nil {
				return 0, err
			}
			return uint16(parseUint), nil
		},
		CheckMachineID: func(u uint16) bool {
			return true
		},
	}
	sf = sonyflake.NewSonyflake(sfSetting)
	if sf == nil {
		return fmt.Errorf("初始化sonyflake失败")
	}
	return nil
}

func getWorkerID() error {
	//获取持久根节点下的子节点 用于判断当前节点是否注册过
	foreverChildren, _, err := conn.Children(cfg.foreverRootPath)
	if err != nil {
		return fmt.Errorf("获取%v子节点发生错误：%w", cfg.foreverRootPath, err)
	}
	for _, path := range foreverChildren {
		//子节点路径格式 类似 127.0.0.1:10000-0000000001
		tmp := strings.Split(path, "-")
		if tmp[0] == cfg.nodeEndpoint {
			workerID = tmp[1]
			nodeFullForeverPath = cfg.foreverRootPath + "/" + path
		}
	}

	//未注册过，进行节点注册
	if strings.TrimSpace(workerID) == "" {
		path := cfg.foreverRootPath + "/" + cfg.nodeEndpoint + "-"
		state, err := createState()
		if err != nil {
			return err
		}
		//在父持久节点下创建持久化的顺序节点 节点的序号就是workerID
		path, err = conn.Create(path, state, zk.FlagSequence, zk.WorldACL(zk.PermAll))
		if err != nil {
			return fmt.Errorf("创建当前节点失败:%w", err)
		}
		tmps := strings.Split(path, "-")
		workerID = tmps[1]
		log.Println("注册节点成功，workerID：", workerID)
		return nil
	}

	//注册过，判断时间是否发生回退
	state, err := getState(nodeFullForeverPath)
	if err != nil {
		return err
	}
	now := time.Now()
	if now.Before(state.LastTime) {
		return fmt.Errorf("发生时间回退：当前时间：%v，节点最后上报时间：%v", now.Format(timeFmt), state.LastTime.Format(timeFmt))
	}
	log.Println("workerID：", workerID, "时间正常未回退")
	return nil
}

func getState(nodeFullPath string) (*nodeState, error) {
	nodeData, _, err := conn.Get(nodeFullPath)
	if err != nil {
		return nil, fmt.Errorf("获取%v节点数据失败：%w", nodeFullPath, err)
	}
	var state nodeState
	err = json.Unmarshal(nodeData, &state)
	if err != nil {
		return nil, fmt.Errorf("反序列化节点状态数据失败：%w", err)
	}
	return &state, nil

}

func createState() ([]byte, error) {
	now := time.Now()
	state := nodeState{
		LastTimestamp: now.UnixNano() / 1000000, //1000000纳秒等于1毫秒
		LastTime:      now,
	}
	stateJson, err := json.Marshal(state)
	if err != nil {
		return nil, fmt.Errorf("序列化节点状态数据发生错误：%w", err)
	}
	return stateJson, nil
}

func checkRootPath() error {
	exists, _, err := conn.Exists(cfg.foreverRootPath)
	if err != nil {
		return fmt.Errorf("检查%v失败:%w", cfg.foreverRootPath, err)
	}
	if !exists {
		return fmt.Errorf("持久父节点:%v 不存在", cfg.foreverRootPath)
	}

	exists, _, err = conn.Exists(cfg.tempRootPath)
	if err != nil {
		return fmt.Errorf("检查%v失败:%w", cfg.tempRootPath, err)
	}
	if !exists {
		return fmt.Errorf("临时父节点:%v 不存在", cfg.tempRootPath)
	}
	return nil
}

func connectToZK(zkServers []string) error {
	//todo: sessionTimeout很重要 关系到临时节点的删除延迟 这块考虑重新设置一个更合理的值或公开
	var err error
	conn, _, err = zk.Connect(zkServers, time.Second*10000)
	if err != nil {
		return fmt.Errorf("连接zookeeper失败:%w", err)
	}
	return nil
}

func genID(sf *sonyflake.Sonyflake, now int64) (uint64, error) {
	id, err := sf.NextID()
	if err != nil {
		return 0, err
	}
	lastGenTimestamp = now
	return id, nil
}

func getNowUnixNano() int64 {
	return time.Now().UnixNano()
}

//reportNodeState 定时上报状态
func reportNodeState() {
	go func() {
		for {
			updateState()
			time.Sleep(500 * time.Millisecond)
		}
	}()
}

func updateState() {
	//当前节点时间发生回拨不应该上报状态
	now := time.Now().UnixNano()
	if lastReportTimestamp != 0 && now < lastReportTimestamp {
		return
	}
	_ = retry(func() error {
		state, err := createState()
		if err != nil {
			return err
		}
		_, err = conn.Set(nodeFullForeverPath, state, -1)
		if err != nil {
			return err
		}
		//todo: 目前采用临时节点的状态时间作为集群时间有效性判断依据 后续改成rpc请求后就不需要了
		_, err = conn.Set(nodeFullTempPath, state, -1)
		if err != nil {
			return err
		}
		lastReportTimestamp = now
		return nil
	}, 3, 250*time.Millisecond)
}

//nodeState 节点状态
type nodeState struct {
	//1970开始的毫秒
	LastTimestamp int64     `json:"last_timestamp"`
	LastTime      time.Time `json:"last_time"`
}

//initConfig 初始化配置信息
type initConfig struct {
	startDate       string
	foreverRootPath string
	tempRootPath    string
	nodeEndpoint    string
}

type Option func(cfg *initConfig)

//WithStartDate 指定雪花算法的开始时间 默认为2020-01-01
func WithStartDate(startDate string) Option {
	return func(cfg *initConfig) {
		cfg.startDate = startDate
	}
}

//WithNodeEndpoint 指定当前节点EP 不指定为第一个非回环的IPV4地址+10000端口
func WithNodeEndpoint(ep string) Option {
	return func(cfg *initConfig) {
		cfg.nodeEndpoint = ep
	}
}

//WithForeverRootPath 指定持久节点根节点 默认为/dna_forever
func WithForeverRootPath(foreverRootPath string) Option {
	return func(cfg *initConfig) {
		cfg.foreverRootPath = foreverRootPath
	}
}

//WithTempRootPath 指定临时节点根节点 默认为/dna_temp
func WithTempRootPath(tempRootPath string) Option {
	return func(cfg *initConfig) {
		cfg.tempRootPath = tempRootPath
	}
}

func retry(exec func() error, retryCount int, waitD time.Duration) error {
	var retErr error
	for i := 0; i < retryCount; i++ {
		err := exec()
		if err == nil {
			retErr = nil
			break
		}
		retErr = err
		time.Sleep(waitD)
	}
	return retErr
}

func getLocalIP() (net.IP, error) {
	ifaces, err := net.Interfaces()
	if err != nil {
		return nil, err
	}
	for _, iface := range ifaces {
		if iface.Flags&net.FlagUp == 0 {
			//不要down掉的接口
			continue
		}
		if iface.Flags&net.FlagLoopback != 0 {
			//不要环回接口
			continue
		}
		addrs, err := iface.Addrs()
		if err != nil {
			return nil, err
		}
		for _, addr := range addrs {
			ip := getIPByAddr(addr)
			if ip == nil {
				continue
			}
			return ip, nil
		}
	}
	return nil, errors.New("无法找到任何IP地址")
}

func getIPByAddr(addr net.Addr) net.IP {
	var ip net.IP
	switch v := addr.(type) {
	case *net.IPNet:
		ip = v.IP
	case *net.IPAddr:
		ip = v.IP
	}
	if ip == nil || ip.IsLoopback() {
		return nil
	}
	ip = ip.To4()
	if ip == nil {
		return nil //不是IPV4地址
	}
	return ip
}
